package it.uniroma2.dtk.dt.subpath;

import it.uniroma2.dtk.common.Spectrum;
import it.uniroma2.dtk.dt.DT;
import it.uniroma2.util.math.ArrayMath;
import it.uniroma2.util.math.MatrixUtils;
import it.uniroma2.util.tree.Tree;
import it.uniroma2.util.vector.RandomVectorGenerator;
import it.uniroma2.util.vector.VectorComposer;
import it.uniroma2.util.vector.VectorProvider;

/**
 * Implementation of the {@link DT} interface for the Subpath Tree Kernel presented in:
 * 		Kimura, D., Kuboyama, T., Shibuya, T., and Kashima, H.
 * 		A subpath kernel for rooted unordered trees. In PAKDD, 2011.
 * This class abstracts from the actual ideal composition function implementation:
 * concrete subclasses provide it through {@link #op(double[], double[])}.
 *
 * @author Lorenzo Dell'Arciprete
 */
public abstract class SubPathAbstractDT implements DT {

	protected VectorProvider vectorProvider;	// RandomVectorGenerator is the default implementation.
	
	protected int vectorSize;
	protected boolean usePos;
	// Square root of the lambda decay parameter (despite the name, this is NOT lambda squared).
	protected double lambdaSq = 1;
	
	/**
	 * @param randomOffset - random seed for vector generation and composition
	 * @param vectorsSize - the dimension of the desired vector space
	 * @param usePos - if true, POS information will be considered for the tree node labels
	 * @param lambda - the value of the lambda decaying factor; must be strictly positive
	 * @throws IllegalArgumentException if lambda is not a strictly positive number
	 * @throws Exception propagated from the RandomVectorGenerator constructor
	 */
	public SubPathAbstractDT(int randomOffset, int vectorsSize, boolean usePos, double lambda) throws Exception {
		// Validate first so that no work is done for an unusable configuration
		lambdaSq = checkedSqrtLambda(lambda);
		vectorProvider = new RandomVectorGenerator(vectorsSize, randomOffset);
		this.vectorSize = vectorsSize;
		this.usePos = usePos;
	}
	
	/**
	 * Same as {@link #SubPathAbstractDT(int, int, boolean, double)} with lambda = 1 (no decay).
	 */
	public SubPathAbstractDT(int randomOffset, int vectorsSize, boolean usePos) throws Exception {
		this(randomOffset, vectorsSize, usePos, 1);
	}
	
	/**
	 * Same as {@link #SubPathAbstractDT(int, int, boolean, double)} with usePos = false.
	 */
	public SubPathAbstractDT(int randomOffset, int vectorsSize, double lambda) throws Exception {
		this(randomOffset, vectorsSize, false, lambda);
	}
	
	/**
	 * Validates the lambda decay factor and returns its square root.
	 * The check is written as !(lambda > 0) so that NaN is rejected too: a negative or NaN
	 * lambda would make Math.sqrt return NaN and poison every vector, and a zero lambda
	 * would zero out dt() and make dtf() divide by zero.
	 */
	private static double checkedSqrtLambda(double lambda) {
		if (!(lambda > 0)) {
			throw new IllegalArgumentException("lambda must be a strictly positive number, got: " + lambda);
		}
		return Math.sqrt(lambda);
	}
	
	/**
	 * The vector composition function, supplied by concrete subclasses.
	 * @param v1 - the left operand (the node label vector, in this class' usage)
	 * @param v2 - the right operand (the composed vector of the rest of the path)
	 * @return the composition of v1 and v2
	 * @throws Exception as declared by the implementation
	 */
	public abstract double[] op(double[] v1, double[] v2) throws Exception;
	
	/** @return the dimension of the vector space */
	public int getVectorSize() {return vectorSize;}
	
	/** @return the provider used to map node labels to vectors */
	public VectorProvider getVectorProvider() {return vectorProvider;}
	
	/**
	 * Computes the distributed tree vector of x, i.e. the sum of s(n) over all nodes n of x.
	 * Any exception raised during the computation is printed, and the (possibly partial)
	 * sum accumulated so far is returned.
	 * @param x - the input tree
	 * @return the distributed tree vector of x
	 */
	public double[] dt(Tree x) {
		//The Spectrum object will collect the values of s(n) during the course of its computation
		Spectrum result = new Spectrum();
		result.setVector(MatrixUtils.uniformVector(vectorSize, 0));
		try {
			sRecursive(x, result);
		}
		catch(Exception e) {
			e.printStackTrace();
		}
		return result.getVector();
	}
	
	/**
	 * Recursive computation of function s(n) for the root of the input tree:
	 * s(n) = sqrt(lambda) * (L(n) + op(L(n), sum of s(c) over the children c of n)),
	 * where the op(...) term is omitted for terminal nodes and L(n) is the label vector of n.
	 * To save time and space, the computed s(n) is directly added to the final sum.
	 * @param node - the input tree or, equivalently, its root node
	 * @param sum - the object collecting the sum of s(n) for each node in the tree
	 * @return s(node)
	 * @throws Exception propagated from op or getLabelVector
	 */
	protected double[] sRecursive(Tree node, Spectrum sum) throws Exception {
		double[] result = MatrixUtils.uniformVector(vectorSize, 0);
		//Recursive computation of s(n).
		if (!node.isTerminal()) {
			for (int i=0; i<node.getChildren().size(); i++) {
				Tree child = node.getChildren().get(i);
				result = VectorComposer.sum(result, sRecursive(child, sum));
			}
			result = op(getLabelVector(node), result);
		}
		result = ArrayMath.scalardot(lambdaSq, VectorComposer.sum(getLabelVector(node), result));
		//Adding s(node) to the final sum
		sum.setVector(VectorComposer.sum(sum.getVector(), result));
		return result;
	}
	
	/**
	 * Computes the vector of a single tree fragment, which for this kernel MUST be a path
	 * (every non-terminal node has exactly one child).
	 * The path is composed from its end to its start, and every node contributes a factor
	 * 1/sqrt(lambda), i.e. the inverse of the per-node weight applied in sRecursive.
	 * (Presumably so that the dot product with dt() counts occurrences unweighted — confirm.)
	 * If x is not a path, or any error occurs, the stack trace is printed and a zero
	 * vector is returned.
	 * @param x - the path fragment
	 * @return the fragment vector
	 */
	public double[] dtf(Tree x) {
		try {
			return dtfRecursive(x);
		}
		catch(Exception e) {
			e.printStackTrace();
			return MatrixUtils.uniformVector(vectorSize, 0);
		}
	}
	
	/**
	 * Recursive step of dtf. A terminal node (or a node with no children) ends the path.
	 * Exceptions propagate to dtf so that an invalid fragment is reported only once.
	 */
	private double[] dtfRecursive(Tree node) throws Exception {
		// Label requested before recursing, matching the original evaluation order
		double[] label = getLabelVector(node);
		double[] composed;
		if (node.isTerminal() || node.getChildren().size() == 0) {
			// Base case: the path ends here. Copy so that scaling can never alter the
			// provider's label vector (in case ArrayMath.scalardot works in place — TODO confirm).
			composed = label.clone();
		}
		else if (node.getChildren().size() == 1) {
			//Recursive computation preserves the right composition order (from end to start of the path)
			composed = op(label, dtfRecursive(node.getChildren().get(0)));
		}
		else {
			throw new Exception("Tree fragments for the Subpath Kernel MUST be paths: "+node.toPennTree());
		}
		return ArrayMath.scalardot(1/lambdaSq, composed);
	}
	
	/**
	 * Returns the vector associated with the label of node: its POS-augmented label
	 * (getUsePosLabel) when usePos is true, otherwise its root label.
	 * @param node - the tree node
	 * @return the label vector obtained from the vector provider
	 * @throws Exception propagated from the vector provider
	 */
	public double[] getLabelVector(Tree node) throws Exception {
		if (usePos) {
			return vectorProvider.getVector(node.getUsePosLabel());
		}
		return vectorProvider.getVector(node.getRootLabel());
	}
	
	/**
	 * @return the lambda decay factor, reconstructed as the square of the stored square root
	 * (it may differ from the value that was set by floating-point rounding)
	 */
	public double getLambda() {
		return lambdaSq*lambdaSq;
	}
	
	/**
	 * @param lambda - the new lambda decay factor; must be strictly positive
	 * @throws IllegalArgumentException if lambda is not a strictly positive number
	 */
	public void setLambda(double lambda) {
		lambdaSq = checkedSqrtLambda(lambda);
	}

	/** @return true if POS information is considered for node labels */
	public boolean isUsePos() {
		return usePos;
	}

	/** @param usePos - if true, POS information will be considered for node labels */
	public void setUsePos(boolean usePos) {
		this.usePos = usePos;
	}

}
